{
  "cells": [
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "4Qs9wmH6Jwt5"
      },
      "source": [
        "# Predicting Links with Graph Neural Networks"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "6muoc6CEld7m",
        "outputId": "8e544430-ab0f-475f-eb1d-6b0e302c8b4d"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "d:\\Programs\\Anaconda\\envs\\book\\lib\\site-packages\\tqdm\\auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
            "  from .autonotebook import tqdm as notebook_tqdm\n"
          ]
        }
      ],
      "source": [
        "import torch\n",
        "!pip install -q torch-scatter~=2.1.0 torch-sparse~=0.6.16 torch-cluster~=1.6.0 torch-spline-conv~=1.2.1 torch-geometric==2.2.0 -f https://data.pyg.org/whl/torch-{torch.__version__}.html\n",
        "\n",
        "torch.manual_seed(0)\n",
        "torch.cuda.manual_seed(0)\n",
        "torch.cuda.manual_seed_all(0)\n",
        "torch.backends.cudnn.deterministic = True\n",
        "torch.backends.cudnn.benchmark = False"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Graph Autoencoder (VAE) and Variational Graph Autoencoder (VGAE)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 14,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "EgqVUSiulquq",
        "outputId": "96b273c1-9832-47e0-bde2-bc8d764c5e22"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "np.random.seed(0)\n",
        "import torch\n",
        "torch.manual_seed(0)\n",
        "import matplotlib.pyplot as plt\n",
        "import torch_geometric.transforms as T\n",
        "from torch_geometric.datasets import Planetoid\n",
        "\n",
        "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
        "\n",
        "transform = T.Compose([\n",
        "    T.NormalizeFeatures(),\n",
        "    T.ToDevice(device),\n",
        "    T.RandomLinkSplit(num_val=0.05, num_test=0.1, is_undirected=True, split_labels=True, add_negative_train_samples=False),\n",
        "])\n",
        "\n",
        "dataset = Planetoid('.', name='Cora', transform=transform)\n",
        "\n",
        "train_data, val_data, test_data = dataset[0]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 15,
      "metadata": {
        "id": "Hc0ZiEn1lqkq"
      },
      "outputs": [],
      "source": [
        "from torch_geometric.nn import GCNConv, VGAE\n",
        "\n",
        "class Encoder(torch.nn.Module):\n",
        "    def __init__(self, dim_in, dim_out):\n",
        "        super().__init__()\n",
        "        self.conv1 = GCNConv(dim_in, 2 * dim_out)\n",
        "        self.conv_mu = GCNConv(2 * dim_out, dim_out)\n",
        "        self.conv_logstd = GCNConv(2 * dim_out, dim_out)\n",
        "\n",
        "    def forward(self, x, edge_index):\n",
        "        x = self.conv1(x, edge_index).relu()\n",
        "        return self.conv_mu(x, edge_index), self.conv_logstd(x, edge_index)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 16,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "SaKaQxpYmC_5",
        "outputId": "4af5b67c-a3e5-4220-9ab9-29d3c2bed6fb"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Epoch  0 | Loss: 3.4210 | Val AUC: 0.6772 | Val AP: 0.7110\n",
            "Epoch 50 | Loss: 1.3324 | Val AUC: 0.6593 | Val AP: 0.6922\n",
            "Epoch 100 | Loss: 1.1675 | Val AUC: 0.7366 | Val AP: 0.7298\n",
            "Epoch 150 | Loss: 1.1166 | Val AUC: 0.7480 | Val AP: 0.7514\n",
            "Epoch 200 | Loss: 1.0074 | Val AUC: 0.8390 | Val AP: 0.8395\n",
            "Epoch 250 | Loss: 0.9541 | Val AUC: 0.8794 | Val AP: 0.8797\n",
            "Epoch 300 | Loss: 0.9509 | Val AUC: 0.8833 | Val AP: 0.8845\n",
            "Test AUC: 0.8833 | Test AP 0.8845\n"
          ]
        }
      ],
      "source": [
        "model = VGAE(Encoder(dataset.num_features, 16)).to(device)\n",
        "optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n",
        "\n",
        "def train():\n",
        "    model.train()\n",
        "    optimizer.zero_grad()\n",
        "    z = model.encode(train_data.x, train_data.edge_index)\n",
        "    loss = model.recon_loss(z, train_data.pos_edge_label_index) + (1 / train_data.num_nodes) * model.kl_loss()\n",
        "    loss.backward()\n",
        "    optimizer.step()\n",
        "    return float(loss)\n",
        "\n",
        "@torch.no_grad()\n",
        "def test(data):\n",
        "    model.eval()\n",
        "    z = model.encode(data.x, data.edge_index)\n",
        "    return model.test(z, data.pos_edge_label_index, data.neg_edge_label_index)\n",
        "\n",
        "for epoch in range(301):\n",
        "    loss = train()\n",
        "    val_auc, val_ap = test(test_data)\n",
        "    if epoch % 50 == 0:\n",
        "        print(f'Epoch {epoch:>2} | Loss: {loss:.4f} | Val AUC: {val_auc:.4f} | Val AP: {val_ap:.4f}') \n",
        "\n",
        "test_auc, test_ap = test(test_data) \n",
        "print(f'Test AUC: {test_auc:.4f} | Test AP {test_ap:.4f}')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {},
      "outputs": [
        {
          "data": {
            "text/plain": [
              "tensor([[0.8846, 0.5068, 0.7074,  ..., 0.5160, 0.8309, 0.8378],\n",
              "        [0.5068, 0.8741, 0.7878,  ..., 0.3900, 0.5367, 0.5495],\n",
              "        [0.7074, 0.7878, 0.8714,  ..., 0.4318, 0.7806, 0.7602],\n",
              "        ...,\n",
              "        [0.5160, 0.3900, 0.4318,  ..., 0.5855, 0.5350, 0.5176],\n",
              "        [0.8309, 0.5367, 0.7806,  ..., 0.5350, 0.8443, 0.8275],\n",
              "        [0.8378, 0.5495, 0.7602,  ..., 0.5176, 0.8275, 0.8200]],\n",
              "       device='cuda:0', grad_fn=<SigmoidBackward0>)"
            ]
          },
          "execution_count": 5,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "z = model.encode(test_data.x, test_data.edge_index) \n",
        "Ahat = torch.sigmoid(z @ z.T)\n",
        "Ahat"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## SEAL"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "tNNO18ePLd1W",
        "outputId": "c679fa31-3fbd-4971-c127-1079f13c9f04"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "from sklearn.metrics import roc_auc_score, average_precision_score\n",
        "from scipy.sparse.csgraph import shortest_path\n",
        "\n",
        "import torch.nn.functional as F\n",
        "from torch.nn import Conv1d, MaxPool1d, Linear, Dropout, BCEWithLogitsLoss\n",
        "\n",
        "from torch_geometric.datasets import Planetoid\n",
        "from torch_geometric.transforms import RandomLinkSplit\n",
        "from torch_geometric.data import Data\n",
        "from torch_geometric.loader import DataLoader\n",
        "from torch_geometric.nn import GCNConv, aggr\n",
        "from torch_geometric.utils import k_hop_subgraph, to_scipy_sparse_matrix"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "id": "XqlCeawwm0Pp"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "Data(x=[2708, 1433], edge_index=[2, 8976], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708], pos_edge_label=[4488], pos_edge_label_index=[2, 4488], neg_edge_label=[4488], neg_edge_label_index=[2, 4488])"
            ]
          },
          "execution_count": 8,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# Load Cora dataset\n",
        "transform = RandomLinkSplit(num_val=0.05, num_test=0.1, is_undirected=True, split_labels=True)\n",
        "dataset = Planetoid('.', name='Cora', transform=transform)\n",
        "train_data, val_data, test_data = dataset[0]\n",
        "train_data"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "id": "tCSfv3PSMQ-e"
      },
      "outputs": [],
      "source": [
        "def seal_processing(dataset, edge_label_index, y):\n",
        "    data_list = []\n",
        "\n",
        "    for src, dst in edge_label_index.t().tolist():\n",
        "        sub_nodes, sub_edge_index, mapping, _ = k_hop_subgraph([src, dst], 2, dataset.edge_index, relabel_nodes=True)\n",
        "        src, dst = mapping.tolist()\n",
        "\n",
        "        # Remove target link from the subgraph\n",
        "        mask1 = (sub_edge_index[0] != src) | (sub_edge_index[1] != dst)\n",
        "        mask2 = (sub_edge_index[0] != dst) | (sub_edge_index[1] != src)\n",
        "        sub_edge_index = sub_edge_index[:, mask1 & mask2]\n",
        "\n",
        "        # Double-radius node labeling (DRNL)\n",
        "        src, dst = (dst, src) if src > dst else (src, dst)\n",
        "        adj = to_scipy_sparse_matrix(sub_edge_index, num_nodes=sub_nodes.size(0)).tocsr()\n",
        "\n",
        "        idx = list(range(src)) + list(range(src + 1, adj.shape[0]))\n",
        "        adj_wo_src = adj[idx, :][:, idx]\n",
        "\n",
        "        idx = list(range(dst)) + list(range(dst + 1, adj.shape[0]))\n",
        "        adj_wo_dst = adj[idx, :][:, idx]\n",
        "\n",
        "        # Calculate the distance between every node and the source target node\n",
        "        d_src = shortest_path(adj_wo_dst, directed=False, unweighted=True, indices=src)\n",
        "        d_src = np.insert(d_src, dst, 0, axis=0)\n",
        "        d_src = torch.from_numpy(d_src)\n",
        "\n",
        "        # Calculate the distance between every node and the destination target node\n",
        "        d_dst = shortest_path(adj_wo_src, directed=False, unweighted=True, indices=dst-1)\n",
        "        d_dst = np.insert(d_dst, src, 0, axis=0)\n",
        "        d_dst = torch.from_numpy(d_dst)\n",
        "\n",
        "        # Calculate the label z for each node\n",
        "        dist = d_src + d_dst\n",
        "        z = 1 + torch.min(d_src, d_dst) + dist // 2 * (dist // 2 + dist % 2 - 1)\n",
        "        z[src], z[dst], z[torch.isnan(z)] = 1., 1., 0.\n",
        "        z = z.to(torch.long)\n",
        "\n",
        "        # Concatenate node features and one-hot encoded node labels (with a fixed number of classes)\n",
        "        node_labels = F.one_hot(z, num_classes=200).to(torch.float)\n",
        "        node_emb = dataset.x[sub_nodes]\n",
        "        node_x = torch.cat([node_emb, node_labels], dim=1)\n",
        "\n",
        "        # Create data object\n",
        "        data = Data(x=node_x, z=z, edge_index=sub_edge_index, y=y)\n",
        "        data_list.append(data)\n",
        "\n",
        "    return data_list"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "1iPBfn1LLip-",
        "outputId": "afe486e3-ab37-4c05-c56b-d12af68fbc0e"
      },
      "outputs": [],
      "source": [
        "# Enclosing subgraphs extraction\n",
        "train_pos_data_list = seal_processing(train_data, train_data.pos_edge_label_index, 1)\n",
        "train_neg_data_list = seal_processing(train_data, train_data.neg_edge_label_index, 0)\n",
        "\n",
        "val_pos_data_list = seal_processing(val_data, val_data.pos_edge_label_index, 1)\n",
        "val_neg_data_list = seal_processing(val_data, val_data.neg_edge_label_index, 0)\n",
        "\n",
        "test_pos_data_list = seal_processing(test_data, test_data.pos_edge_label_index, 1)\n",
        "test_neg_data_list = seal_processing(test_data, test_data.neg_edge_label_index, 0)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {
        "id": "dU_P2-JlR55j"
      },
      "outputs": [],
      "source": [
        "train_dataset = train_pos_data_list + train_neg_data_list\n",
        "val_dataset = val_pos_data_list + val_neg_data_list\n",
        "test_dataset = test_pos_data_list + test_neg_data_list\n",
        "\n",
        "train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)\n",
        "val_loader = DataLoader(val_dataset, batch_size=32)\n",
        "test_loader = DataLoader(test_dataset, batch_size=32)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 12,
      "metadata": {
        "id": "OyuODuSqP6iu"
      },
      "outputs": [],
      "source": [
        "class DGCNN(torch.nn.Module):\n",
        "    def __init__(self, dim_in, k=30):\n",
        "        super().__init__()\n",
        "\n",
        "        # GCN layers\n",
        "        self.gcn1 = GCNConv(dim_in, 32)\n",
        "        self.gcn2 = GCNConv(32, 32)\n",
        "        self.gcn3 = GCNConv(32, 32)\n",
        "        self.gcn4 = GCNConv(32, 1)\n",
        "\n",
        "        # Global sort pooling\n",
        "        self.global_pool = aggr.SortAggregation(k=k)\n",
        "\n",
        "        # Convolutional layers\n",
        "        self.conv1 = Conv1d(1, 16, 97, 97)\n",
        "        self.conv2 = Conv1d(16, 32, 5, 1)\n",
        "        self.maxpool = MaxPool1d(2, 2)\n",
        "\n",
        "        # Dense layers\n",
        "        self.linear1 = Linear(352, 128)\n",
        "        self.dropout = Dropout(0.5)\n",
        "        self.linear2 = Linear(128, 1)\n",
        "\n",
        "    def forward(self, x, edge_index, batch):\n",
        "        # 1. Graph Convolutional Layers\n",
        "        h1 = self.gcn1(x, edge_index).tanh()\n",
        "        h2 = self.gcn2(h1, edge_index).tanh()\n",
        "        h3 = self.gcn3(h2, edge_index).tanh()\n",
        "        h4 = self.gcn4(h3, edge_index).tanh()\n",
        "        h = torch.cat([h1, h2, h3, h4], dim=-1)\n",
        "\n",
        "        # 2. Global sort pooling\n",
        "        h = self.global_pool(h, batch)\n",
        "\n",
        "        # 3. Traditional convolutional and dense layers\n",
        "        h = h.view(h.size(0), 1, h.size(-1))\n",
        "        h = self.conv1(h).relu()\n",
        "        h = self.maxpool(h)\n",
        "        h = self.conv2(h).relu()\n",
        "        h = h.view(h.size(0), -1)\n",
        "        h = self.linear1(h).relu()\n",
        "        h = self.dropout(h)\n",
        "        h = self.linear2(h).sigmoid()\n",
        "\n",
        "        return h"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 13,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 648
        },
        "id": "kcivDBP4PjDx",
        "outputId": "d733477a-0263-403d-fd1a-6ee3ec8950be"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Epoch  0 | Loss: 0.6925 | Val AUC: 0.8215 | Val AP: 0.8357\n",
            "Epoch  1 | Loss: 0.6203 | Val AUC: 0.8543 | Val AP: 0.8712\n",
            "Epoch  2 | Loss: 0.5888 | Val AUC: 0.8783 | Val AP: 0.8877\n",
            "Epoch  3 | Loss: 0.5815 | Val AUC: 0.8825 | Val AP: 0.8920\n",
            "Epoch  4 | Loss: 0.5776 | Val AUC: 0.8917 | Val AP: 0.9025\n",
            "Epoch  5 | Loss: 0.5747 | Val AUC: 0.8971 | Val AP: 0.9079\n",
            "Epoch  6 | Loss: 0.5719 | Val AUC: 0.8932 | Val AP: 0.9027\n",
            "Epoch  7 | Loss: 0.5688 | Val AUC: 0.9064 | Val AP: 0.9119\n",
            "Epoch  8 | Loss: 0.5646 | Val AUC: 0.9007 | Val AP: 0.9112\n",
            "Epoch  9 | Loss: 0.5599 | Val AUC: 0.9064 | Val AP: 0.9072\n",
            "Epoch 10 | Loss: 0.5565 | Val AUC: 0.9037 | Val AP: 0.9043\n",
            "Epoch 11 | Loss: 0.5538 | Val AUC: 0.9036 | Val AP: 0.9046\n",
            "Epoch 12 | Loss: 0.5521 | Val AUC: 0.9038 | Val AP: 0.9051\n",
            "Epoch 13 | Loss: 0.5509 | Val AUC: 0.8997 | Val AP: 0.8980\n",
            "Epoch 14 | Loss: 0.5493 | Val AUC: 0.9018 | Val AP: 0.9027\n",
            "Epoch 15 | Loss: 0.5483 | Val AUC: 0.8985 | Val AP: 0.8959\n",
            "Epoch 16 | Loss: 0.5482 | Val AUC: 0.8990 | Val AP: 0.8963\n",
            "Epoch 17 | Loss: 0.5480 | Val AUC: 0.8991 | Val AP: 0.8955\n",
            "Epoch 18 | Loss: 0.5487 | Val AUC: 0.9022 | Val AP: 0.8984\n",
            "Epoch 19 | Loss: 0.5476 | Val AUC: 0.9035 | Val AP: 0.8990\n",
            "Epoch 20 | Loss: 0.5470 | Val AUC: 0.9011 | Val AP: 0.8984\n",
            "Epoch 21 | Loss: 0.5470 | Val AUC: 0.9003 | Val AP: 0.8980\n",
            "Epoch 22 | Loss: 0.5470 | Val AUC: 0.8957 | Val AP: 0.8950\n",
            "Epoch 23 | Loss: 0.5467 | Val AUC: 0.8944 | Val AP: 0.8954\n",
            "Epoch 24 | Loss: 0.5469 | Val AUC: 0.8919 | Val AP: 0.8946\n",
            "Epoch 25 | Loss: 0.5482 | Val AUC: 0.9016 | Val AP: 0.8989\n",
            "Epoch 26 | Loss: 0.5462 | Val AUC: 0.9014 | Val AP: 0.8994\n",
            "Epoch 27 | Loss: 0.5459 | Val AUC: 0.9000 | Val AP: 0.8992\n",
            "Epoch 28 | Loss: 0.5457 | Val AUC: 0.9054 | Val AP: 0.9018\n",
            "Epoch 29 | Loss: 0.5461 | Val AUC: 0.8991 | Val AP: 0.8973\n",
            "Epoch 30 | Loss: 0.5460 | Val AUC: 0.9005 | Val AP: 0.8992\n",
            "Test AUC: 0.8808 | Test AP 0.8863\n"
          ]
        }
      ],
      "source": [
        "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
        "model = DGCNN(train_dataset[0].num_features).to(device)\n",
        "optimizer = torch.optim.Adam(params=model.parameters(), lr=0.0001)\n",
        "criterion = BCEWithLogitsLoss()\n",
        "\n",
        "def train():\n",
        "    model.train()\n",
        "    total_loss = 0\n",
        "\n",
        "    for data in train_loader:\n",
        "        data = data.to(device)\n",
        "        optimizer.zero_grad()\n",
        "        out = model(data.x, data.edge_index, data.batch)\n",
        "        loss = criterion(out.view(-1), data.y.to(torch.float))\n",
        "        loss.backward()\n",
        "        optimizer.step()\n",
        "        total_loss += float(loss) * data.num_graphs\n",
        "\n",
        "    return total_loss / len(train_dataset)\n",
        "\n",
        "@torch.no_grad()\n",
        "def test(loader):\n",
        "    model.eval()\n",
        "    y_pred, y_true = [], []\n",
        "\n",
        "    for data in loader:\n",
        "        data = data.to(device)\n",
        "        out = model(data.x, data.edge_index, data.batch)\n",
        "        y_pred.append(out.view(-1).cpu())\n",
        "        y_true.append(data.y.view(-1).cpu().to(torch.float))\n",
        "\n",
        "    auc = roc_auc_score(torch.cat(y_true), torch.cat(y_pred))\n",
        "    ap = average_precision_score(torch.cat(y_true), torch.cat(y_pred))\n",
        "\n",
        "    return auc, ap\n",
        "\n",
        "for epoch in range(31):\n",
        "    loss = train()\n",
        "    val_auc, val_ap = test(val_loader)\n",
        "    print(f'Epoch {epoch:>2} | Loss: {loss:.4f} | Val AUC: {val_auc:.4f} | Val AP: {val_ap:.4f}')\n",
        "\n",
        "test_auc, test_ap = test(test_loader)\n",
        "print(f'Test AUC: {test_auc:.4f} | Test AP {test_ap:.4f}')"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "display_name": "book",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.15"
    },
    "vscode": {
      "interpreter": {
        "hash": "3556630122da5213751af4465d61fcf5a52cd22515d400aee51118aaa1721248"
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
